import sys, time
import numpy as np
from OpenGL.GL import *
from OpenGL.GLUT import *
from OpenGL.GL.shaders import compileProgram, compileShader

# ---------- Globals ----------
# Mutable GL/runtime state shared by the GLUT callbacks below.
window = None          # NOTE(review): not assigned anywhere in the visible code
shader = None          # linked shader program (set in init_gl)
vao = None             # fullscreen-quad vertex array object (set in init_gl)
Dn_table_tex = None    # Dn(r) lookup texture (set in init_gl); predeclared so
                       # the name exists even before init_gl has run
textures = []          # per-tile [tex_a, tex_b] ping-pong texture pairs
fbos = []              # per-tile [fbo_a, fbo_b]; fbos[t][i] renders into textures[t][i]
current = 0            # which member (0/1) of each pair holds the latest state
cycle = 0.0            # frame counter, uploaded as the `cycle` uniform
omega_time = 0.0       # animation phase, uploaded as `omegaTime`; +0.05 per frame

# ---------- Lattice params ----------
lattice_width = 256               # texels per lattice row (width of every tile)
num_instances_base = 2_000_000    # total lattice rows allocated at start-up
sub_tile_height = 256             # NOTE(review): not referenced in the visible code
max_tex_height = 2048             # max rows per tile texture (last tile may be shorter)
phi = 1.6180339887                # golden ratio
threshold = np.sqrt(phi)          # activation threshold uploaded to the shader

num_Dn = 32  # maximum Base(∞) primitives per slot

# Precompute the Dn(r) lookup table on the CPU:
#   Dn_table[n, i] = sqrt(phi * F[n] * 2**(n+1) * P[n]) * r_i**(n+1)
# where r_i are r_steps evenly spaced samples of [0, 1].
r_steps = 256
# Fibonacci numbers F(1)..F(32).
F = [1,1,2,3,5,8,13,21,34,55,89,144,233,377,610,987,
     1597,2584,4181,6765,10946,17711,28657,46368,75025,121393,196418,317811,514229,832040,1346269,2178309]
# Primes; only the first num_Dn entries are used.
P = [2,3,5,7,11,13,17,19,23,29,31,37,41,43,47,53,
     59,61,67,71,73,79,83,89,97,101,103,107,109,113,127,131,137]
if len(F) < num_Dn or len(P) < num_Dn:
    raise ValueError(
        f"need at least num_Dn={num_Dn} Fibonacci and prime coefficients, "
        f"got {len(F)} and {len(P)}")

Dn_table = np.zeros((num_Dn, r_steps), dtype=np.float32)
r_samples = np.linspace(0.0, 1.0, r_steps)  # loop-invariant, built once
for n in range(num_Dn):
    Dn_table[n] = np.sqrt(phi * F[n] * 2**(n+1) * P[n]) * (r_samples**(n+1))

tile_heights = []   # row count of each tile, filled by reinit_lattice
tile_count = 0      # len(tile_heights), set by reinit_lattice

# ---------- Vertex Shader ----------
# Pass-through vertex shader for the fullscreen quad: emits the 2D clip-space
# position unchanged (z=0, w=1) and remaps it from [-1, 1] to [0, 1] to form
# the texture coordinate consumed by the fragment shader.
VERTEX_SRC = """
#version 330
layout(location = 0) in vec2 pos;
out vec2 texCoord;
void main(){
    texCoord = (pos + 1.0)*0.5;
    gl_Position = vec4(pos,0,1);
}
"""

# ---------- Fragment Shader ----------
# Per-texel lattice update. For the lattice texel under this fragment it:
#   * picks a Dn-table row r_idx from the fragment's relative y position,
#   * sums all Dn(n, r_idx) entries plus a per-n offset (+0.3 / 0 / -0.3 by n mod 3)
#     onto the texel's red channel,
#   * writes 1.0 / 0.0 to red depending on `threshold`, with time-varying
#     green/blue channels.
# Lattice width, Dn count and Dn row count come from textureSize(), so the
# shader stays correct if lattice_width / num_Dn / r_steps change.
# NOTE: the `cycle` uniform is declared but not read; the compiler may drop it,
# and glUniform* calls on its -1 location are then ignored by GL.
FRAGMENT_SRC = """
#version 330
in vec2 texCoord;
out vec4 fragColor;

uniform sampler2D latticeTex;
uniform sampler2D DnTable;
uniform float cycle;
uniform float omegaTime;
uniform float threshold;
uniform int latticeHeight;
uniform int yOffset;

float hybrid_slot(float val, int x, int y){
    ivec2 dnSize = textureSize(DnTable, 0);   // (num_Dn, r_steps)
    int rMax = dnSize.y - 1;
    // Clamp so a non-zero yOffset can never address outside the table.
    int r_idx = clamp(int(float(y)/float(latticeHeight)*float(rMax)), 0, rMax);
    float accum = val;
    for(int n=0; n<dnSize.x; n++){
        float slot = texelFetch(DnTable, ivec2(n,r_idx),0).r;
        float wave = mod(float(n),3.0)==0.0?0.3:(mod(float(n),3.0)==1.0?0.0:-0.3);
        accum += slot + wave;
    }
    return accum > threshold ? 1.0 : 0.0;
}

void main(){
    int x = int(texCoord.x * float(textureSize(latticeTex, 0).x));
    int y = int(texCoord.y * float(latticeHeight)) + yOffset;
    float val = texelFetch(latticeTex, ivec2(x,y-yOffset),0).r;
    float new_val = hybrid_slot(val,x,y);
    fragColor = vec4(new_val, sin(omegaTime), cos(omegaTime*0.5),1.0);
}
"""

# ---------- OpenGL Initialization ----------
def init_gl():
    """Create the GL resources used by ``display``.

    Compiles and links the shader program, builds a VAO holding a fullscreen
    quad (two triangles), uploads ``Dn_table`` into a single-channel float
    texture, and allocates the lattice tiles via ``reinit_lattice``.
    Requires a current GL context (``main`` creates the window first).
    """
    global shader, vao, Dn_table_tex

    shader = compileProgram(compileShader(VERTEX_SRC, GL_VERTEX_SHADER),
                            compileShader(FRAGMENT_SRC, GL_FRAGMENT_SHADER))

    # Two triangles covering clip space [-1, 1]^2: 6 vertices x (x, y).
    verts = np.array([-1, -1,  1, -1, -1,  1,
                       1, -1,  1,  1, -1,  1], dtype=np.float32)
    vao = glGenVertexArrays(1)
    glBindVertexArray(vao)
    vbo = glGenBuffers(1)
    glBindBuffer(GL_ARRAY_BUFFER, vbo)
    glBufferData(GL_ARRAY_BUFFER, verts.nbytes, verts, GL_STATIC_DRAW)
    glVertexAttribPointer(0, 2, GL_FLOAT, GL_FALSE, 0, None)
    glEnableVertexAttribArray(0)

    # Dn lookup texture: width = num_Dn (x = n), height = r_steps (y = r).
    # GL consumes pixel data one row of `width` texels at a time, so the
    # (num_Dn, r_steps) array is transposed to (r_steps, num_Dn); otherwise
    # texelFetch(DnTable, ivec2(n, r)) would not return Dn_table[n, r].
    # R32F rather than R16F: entries reach ~1.4e9, beyond the half-float
    # maximum of 65504, and would otherwise be stored as +inf.
    dn_upload = np.ascontiguousarray(Dn_table.T, dtype=np.float32)
    Dn_table_tex = glGenTextures(1)
    glBindTexture(GL_TEXTURE_2D, Dn_table_tex)
    glTexImage2D(GL_TEXTURE_2D, 0, GL_R32F, num_Dn, r_steps, 0,
                 GL_RED, GL_FLOAT, dn_upload)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)
    glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)

    reinit_lattice(num_instances_base)

# ---------- Lattice Reinit ----------
def reinit_lattice(new_num_instances):
    """(Re)allocate the ping-pong lattice tiles for ``new_num_instances`` rows.

    Deletes any existing tile textures/FBOs, then splits the rows into
    ``tile_count`` tiles of at most ``max_tex_height`` rows (``tile_heights``).
    Each tile gets two ``lattice_width`` x height RGBA16F textures, each
    attached to its own framebuffer. Texture 0 of every pair is zero-filled,
    and ``current`` is reset to 0 so the next ``display`` reads the zeroed
    texture rather than the uninitialized one.

    Args:
        new_num_instances: total number of lattice rows; values <= 0 yield
            no tiles.

    Raises:
        RuntimeError: if a tile framebuffer is not complete.
    """
    global textures, fbos, tile_heights, tile_count, current
    for tex_pair in textures:
        glDeleteTextures(tex_pair)
    for fbo_pair in fbos:
        glDeleteFramebuffers(2, fbo_pair)
    textures.clear()
    fbos.clear()

    # Ceiling division, clamped so negative input cannot give a negative count.
    tile_count = max(0, (new_num_instances + max_tex_height - 1) // max_tex_height)
    tile_heights[:] = [min(max_tex_height, new_num_instances - i * max_tex_height)
                       for i in range(tile_count)]

    # One zero buffer sized for the tallest tile, reused for every tile.
    # A leading-row slice of a C-contiguous array is itself contiguous.
    zeros = np.zeros((max(tile_heights, default=0), lattice_width, 4),
                     dtype=np.float32)

    for th in tile_heights:
        tex_pair = glGenTextures(2)
        fbo_pair = glGenFramebuffers(2)
        # Track the handles immediately so a later reinit can free them even
        # if the completeness check below raises.
        textures.append(tex_pair)
        fbos.append(fbo_pair)
        for tex, fbo in zip(tex_pair, fbo_pair):
            glBindTexture(GL_TEXTURE_2D, tex)
            glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA16F, lattice_width, th, 0,
                         GL_RGBA, GL_HALF_FLOAT, None)
            glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_NEAREST)
            glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_NEAREST)
            glBindFramebuffer(GL_FRAMEBUFFER, fbo)
            glFramebufferTexture2D(GL_FRAMEBUFFER, GL_COLOR_ATTACHMENT0,
                                   GL_TEXTURE_2D, tex, 0)
            status = glCheckFramebufferStatus(GL_FRAMEBUFFER)
            if status != GL_FRAMEBUFFER_COMPLETE:
                glBindFramebuffer(GL_FRAMEBUFFER, 0)
                raise RuntimeError(
                    f"lattice framebuffer incomplete (status 0x{int(status):04X}, "
                    f"tile height {th})")
        glBindTexture(GL_TEXTURE_2D, tex_pair[0])
        glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, lattice_width, th,
                        GL_RGBA, GL_FLOAT, zeros[:th])
    glBindFramebuffer(GL_FRAMEBUFFER, 0)
    current = 0

# ---------- Display ----------
def display():
    """GLUT display callback: advance every lattice tile one step, then draw.

    Simulation pass: for each tile ``t`` the shader samples
    ``textures[t][current]`` (unit 0) together with the Dn table (unit 1) and
    renders into ``fbos[t][next_idx]``. Presentation pass: the freshly
    written ``textures[t][next_idx]`` are drawn to the default framebuffer.
    The buffers are then swapped, ``cycle``/``omega_time`` advance, and
    ``current`` flips so the next frame reads what this frame wrote.
    """
    global cycle, omega_time, current
    next_idx = 1 - current

    # Frame-invariant state is set once per frame rather than once per tile
    # (977 tiles at start-up with the default parameters).
    glUseProgram(shader)
    loc = {name: glGetUniformLocation(shader, name)
           for name in ("latticeTex", "DnTable", "cycle", "omegaTime",
                        "threshold", "latticeHeight", "yOffset")}
    glUniform1i(loc["latticeTex"], 0)
    glUniform1i(loc["DnTable"], 1)
    glUniform1f(loc["cycle"], cycle)
    glUniform1f(loc["omegaTime"], omega_time)
    glUniform1f(loc["threshold"], threshold)
    glUniform1i(loc["yOffset"], 0)
    glBindVertexArray(vao)

    glActiveTexture(GL_TEXTURE1)
    glBindTexture(GL_TEXTURE_2D, Dn_table_tex)
    # Leave unit 0 active for the rest of the frame. Every lattice bind below
    # must target the `latticeTex` unit and never replace the Dn table on
    # unit 1; previously the presentation pass did exactly that.
    glActiveTexture(GL_TEXTURE0)

    # Simulation pass (ping-pong: read `current`, write `next_idx`).
    for t, th in enumerate(tile_heights):
        glBindFramebuffer(GL_FRAMEBUFFER, fbos[t][next_idx])
        glViewport(0, 0, lattice_width, th)
        glUniform1i(loc["latticeHeight"], th)
        glBindTexture(GL_TEXTURE_2D, textures[t][current])
        glDrawArrays(GL_TRIANGLES, 0, 6)

    # Presentation pass, sized to the actual window so resizes are honored.
    glBindFramebuffer(GL_FRAMEBUFFER, 0)
    glViewport(0, 0, glutGet(GLUT_WINDOW_WIDTH), glutGet(GLUT_WINDOW_HEIGHT))
    # NOTE(review): each tile is drawn as a full-viewport quad, so every draw
    # covers the previous one. Unless blending is enabled elsewhere, only the
    # last tile ends up visible. Confirm whether the tiles were meant to be
    # laid out side by side.
    for t, th in enumerate(tile_heights):
        glUniform1i(loc["latticeHeight"], th)
        glBindTexture(GL_TEXTURE_2D, textures[t][next_idx])
        glDrawArrays(GL_TRIANGLES, 0, 6)

    glutSwapBuffers()
    cycle += 1
    omega_time += 0.05
    current = next_idx

# ---------- Idle ----------
def idle():
    """GLUT idle callback: mark the window for redraw so ``display`` keeps running."""
    glutPostRedisplay()

# ---------- Main ----------
def main():
    """Open the GLUT window, build the GL resources, register the callbacks,
    and hand control to the GLUT event loop."""
    width, height = 1280, 720
    glutInit(sys.argv)
    glutInitDisplayMode(GLUT_DOUBLE | GLUT_RGBA)
    glutInitWindowSize(width, height)
    glutCreateWindow(b"HDGL RX480 Max Saturation Hybrid Lattice")

    # GL objects can only be created once the window (and its context) exists.
    init_gl()

    glutDisplayFunc(display)
    glutIdleFunc(idle)
    glutMainLoop()


if __name__ == "__main__":
    main()